{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# NLP Clustering for Modal Spaces\n",
    "**Author:** Eli Hecht (adapted from Alina Dracheva's code).\n",
    "\n",
    "**Purpose:** Compute text embeddings and clusters for possibility generation and decision responses."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "from sentence_transformers import SentenceTransformer #load the model\n",
    "model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')\n",
    "\n",
    "import umap\n",
    "from sklearn.cluster import KMeans\n",
    "from scipy.spatial import distance_matrix\n",
    "from sklearn.cluster import DBSCAN\n",
    "# from sklearn.metrics import silhouette_score\n",
    "import nltk\n",
    "from nltk.tokenize import word_tokenize\n",
    "from nltk.corpus import stopwords\n",
    "\n",
    "import numpy as np\n",
    "import plotly.graph_objects as go\n",
    "import plotly.express as px\n",
    "# import kaleido\n",
    "import plotly.io as pio\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define analysis functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove pronouns and user-specified words from a sentence\n",
    "def filter_words(sentence, words_to_remove):\n",
    "    \"\"\"Return `sentence` without pronouns and without the words in `words_to_remove`.\n",
    "\n",
    "    The sentence is tokenized and part-of-speech tagged with NLTK. Tokens tagged\n",
    "    'PRP' or 'PRP$' (the pronoun tags used here) are dropped, as is any token whose\n",
    "    lowercase form equals the lowercase form of an entry in `words_to_remove`.\n",
    "    The remaining tokens are joined with single spaces.\n",
    "    \"\"\"\n",
    "    tagged_tokens = nltk.pos_tag(word_tokenize(sentence))\n",
    "\n",
    "    banned_tags = ('PRP', 'PRP$')\n",
    "    banned_words = [entry.lower() for entry in words_to_remove]\n",
    "\n",
    "    kept_tokens = []\n",
    "    for token, pos in tagged_tokens:\n",
    "        is_pronoun = pos in banned_tags\n",
    "        if is_pronoun or token.lower() in banned_words:\n",
    "            continue\n",
    "        kept_tokens.append(token)\n",
    "\n",
    "    return ' '.join(kept_tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed every response that belongs to one context\n",
    "def compute_embeddings(df, text_column, context_number, words_to_remove):\n",
    "    \"\"\"Clean and embed the responses of a single context.\n",
    "\n",
    "    Rows of `df` whose 'context' equals `context_number` are selected and\n",
    "    re-indexed with `reset_index()` (so the previous index is kept as a column).\n",
    "    The `text_column` values are cast to `str`, cleaned with `filter_words`\n",
    "    and encoded with the module-level `model`.\n",
    "\n",
    "    Returns the tuple `(text_embeddings, texts_df)`.\n",
    "    \"\"\"\n",
    "    in_context = df['context'] == context_number\n",
    "    texts_df = df[in_context].reset_index()\n",
    "\n",
    "    # Cast to str first, then strip pronouns and the specified words\n",
    "    as_strings = texts_df[text_column].astype(str)\n",
    "    texts_df[text_column] = as_strings.apply(lambda raw: filter_words(raw, words_to_remove))\n",
    "\n",
    "    # One embedding per cleaned response\n",
    "    text_embeddings = model.encode(texts_df[text_column].tolist(), show_progress_bar=True)\n",
    "    return text_embeddings, texts_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# UMAP dimensionality reduction with pre-defined parameters\n",
    "def reduce_dimensions(texts_df, text_embeddings, umap_params):\n",
    "    \"\"\"Project the embeddings with UMAP and store the coordinates in `texts_df`.\n",
    "\n",
    "    Two projections are fitted from `text_embeddings`: one with\n",
    "    `umap_params['n_components']` dimensions (this one is returned) and one with\n",
    "    exactly two dimensions (used for plotting clusters later). `texts_df` is\n",
    "    modified in place and gains the columns `umap_dim_<i>`, `umap_embedding_list`\n",
    "    and `umap_dim2_<i>`.\n",
    "    \"\"\"\n",
    "    def project(n_components):\n",
    "        # Both fits share every setting except the number of output dimensions\n",
    "        reducer = umap.UMAP(n_neighbors=umap_params['n_neighbors'],\n",
    "                            n_components=n_components,\n",
    "                            metric=umap_params['metric'],\n",
    "                            min_dist=umap_params['min_dist'],\n",
    "                            random_state=umap_params['random_state'])\n",
    "        return reducer.fit_transform(text_embeddings)\n",
    "\n",
    "    # Full projection: one column per dimension for easier processing later\n",
    "    umap_embeddings = project(umap_params['n_components'])\n",
    "    dim_columns = []\n",
    "    for dim in range(umap_embeddings.shape[1]):\n",
    "        column = f'umap_dim_{dim}'\n",
    "        texts_df[column] = umap_embeddings[:, dim]\n",
    "        dim_columns.append(column)\n",
    "\n",
    "    # The same coordinates as one list per row, for the centroid distance calculation\n",
    "    texts_df['umap_embedding_list'] = texts_df[dim_columns].apply(lambda row: row.tolist(), axis=1)\n",
    "\n",
    "    # Two-dimensional projection, only used for plotting\n",
    "    umap_embeddings_2 = project(2)\n",
    "    for dim in range(umap_embeddings_2.shape[1]):\n",
    "        texts_df[f'umap_dim2_{dim}'] = umap_embeddings_2[:, dim]\n",
    "\n",
    "    return umap_embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# computes clusters on reduced embeddings\n",
    "def compute_clusters(umap_embeddings, clustering_params):\n",
    "    if clustering_params['algorithm_name'] == 'KMeans':\n",
    "        # initialize KMeans model\n",
    "        clustering_model = KMeans(n_clusters=clustering_params['num_clusters'], n_init='auto')\n",
    "    elif clustering_params['algorithm_name'] == 'DBSCAN':\n",
    "        # initialize DBSCAN model\n",
    "        clustering_model = DBSCAN(eps=clustering_params['eps'], min_samples=clustering_params['min_samples'])\n",
    "    else:\n",
    "        raise ValueError(\"Invalid clustering method. Supported methods are 'kmeans' and 'dbscan'.\")\n",
    "    \n",
    "    # perform clustering using specified paramaters\n",
    "    clustering_model.fit(umap_embeddings)\n",
    "\n",
    "    # labels includes labels of each point, clustering_model\n",
    "    return clustering_model.labels_, clustering_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# function to print clusters for row by row evaluation\n",
    "def print_clusters(texts_df, text_column_name, cluster_assignment, num_clusters):\n",
    "    clustered_sentences = [[] for i in range(num_clusters)]\n",
    "    for sentence_id, cluster_id in enumerate(cluster_assignment):\n",
    "        clustered_sentences[cluster_id].append(texts_df[text_column_name][sentence_id])\n",
    "\n",
    "    for i, cluster in enumerate(clustered_sentences):\n",
    "        print(\"Cluster \", i+1)\n",
    "        cluster_list = []\n",
    "        for item in cluster:\n",
    "            # removes identical context that is included in text of each result\n",
    "            cluster_list.append(item.split(\":\", 1)[-1].strip())\n",
    "        print(cluster_list)\n",
    "        print(\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# function to create elbow-plot to identify ideal number of clusters\n",
    "def k_plot(text_embeddings, umap_params, directory, k_values=None):\n",
    "    \"\"\"Save a KMeans elbow plot computed on UMAP-reduced embeddings.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    text_embeddings : array-like\n",
    "        Embeddings that are reduced with UMAP before clustering.\n",
    "    umap_params : dict\n",
    "        UMAP settings; reads 'n_neighbors', 'n_components', 'metric',\n",
    "        'min_dist' and 'random_state'.\n",
    "    directory : str\n",
    "        Path the figure is saved to (passed straight to `savefig`).\n",
    "    k_values : iterable of int, optional\n",
    "        Numbers of clusters to try. Defaults to 1..17. Values larger than the\n",
    "        number of samples are skipped, because KMeans cannot fit them.\n",
    "    \"\"\"\n",
    "    if k_values is None:\n",
    "        k_values = range(1, 18)\n",
    "\n",
    "    umap_embeddings = (umap.UMAP(n_neighbors=umap_params['n_neighbors'],\n",
    "                                 n_components=umap_params['n_components'],\n",
    "                                 metric=umap_params['metric'],\n",
    "                                 min_dist=umap_params['min_dist'],\n",
    "                                 random_state=umap_params['random_state'])\n",
    "                       .fit_transform(text_embeddings))\n",
    "\n",
    "    # KMeans requires n_clusters <= n_samples, so drop any k that cannot be fitted\n",
    "    n_samples = umap_embeddings.shape[0]\n",
    "    tested_k = [k for k in k_values if 1 <= k <= n_samples]\n",
    "\n",
    "    # perform KMeans for each k value and record its inertia\n",
    "    distortions = []\n",
    "    for k in tested_k:\n",
    "        clustering_model = KMeans(n_clusters=k, n_init='auto')\n",
    "        clustering_model.fit(umap_embeddings)\n",
    "        distortions.append(clustering_model.inertia_)\n",
    "\n",
    "    # Plot the elbow curve on its own figure so no global pyplot state is left behind\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.plot(tested_k, distortions, marker='o')\n",
    "    ax.set_title('Elbow Method For Optimal k')\n",
    "    ax.set_xlabel('Number of Clusters (k)')\n",
    "    ax.set_ylabel('Sum of Squared Distances')\n",
    "    # saves figure to specified location\n",
    "    fig.savefig(directory)\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create a df with clusters, centroids and umap embeddings\n",
    "def compute_centroid(texts_df, text_column_name, original_text_column_name, clustering_params, cluster_assignment, clustering_model):\n",
    "    # attach the cluster assignments to the dataframe\n",
    "    texts_df['cluster'] = pd.Series(cluster_assignment, index=texts_df.index)\n",
    "    \n",
    "    # DBSCAN has no applicable concept of centroids so this is just performed on clusters generated by kMeans\n",
    "    if clustering_params['algorithm_name']=='KMeans':\n",
    "        # Attach the centroids to the dataframe\n",
    "        # In sklearn, the cluster centers are available directly via clustering_model.cluster_centers_\n",
    "        texts_df['centroid'] = texts_df['cluster'].apply(lambda x: clustering_model.cluster_centers_[x])\n",
    "        # Define a function to compute the distance of each embedding from its cluster's centroid\n",
    "        def distance_from_centroid(row):\n",
    "            return distance_matrix([row['umap_embedding_list']], [row['centroid']])[0][0]\n",
    "        texts_df['distance_from_centroid'] = texts_df.apply(distance_from_centroid, axis=1)\n",
    "    else:\n",
    "        texts_df['distance_from_centroid'] = np.nan\n",
    "\n",
    "    # Select the response closest to each cluster centroid to serve as a summary of the cluster\n",
    "    # For DBSCAN the response selected will be whichever happens to be first, but this will still be useful for later naming clusters\n",
    "    summary = texts_df.sort_values('distance_from_centroid', ascending=True).groupby('cluster').head(1).sort_index()[text_column_name].tolist()\n",
    "\n",
    "    # Create a dictionary linking each summary response to its corresponding cluster number\n",
    "    clusters = {}\n",
    "    for i in range(len(summary)):\n",
    "        clusters[summary[i]] = texts_df.loc[texts_df[text_column_name] == summary[i], \"cluster\"].iloc[0]\n",
    "\n",
    "    # create a dictionary with cluster numbers as keys and their centroid texts as values\n",
    "    id_to_name = {v: k for k, v in clusters.items()}\n",
    "\n",
    "    texts_df['cluster_name'] = texts_df['cluster'].map(id_to_name) # create a column with centroid texts\n",
    "\n",
    "    # Create a dictionary with cluster names as keys and original response names of centroids as values\n",
    "    id_to_original_response = {}\n",
    "    for cluster_name, centroid_text in id_to_name.items():\n",
    "        original_response = texts_df.loc[texts_df[text_column_name] == centroid_text, original_text_column_name].iloc[0]\n",
    "        id_to_original_response[cluster_name] = original_response\n",
    "\n",
    "    # Map original response names of centroids to cluster names\n",
    "    texts_df['cluster_name'] = texts_df['cluster'].map(id_to_original_response)\n",
    "\n",
    "    return clusters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function to plot clusters on the joint embedding space\n",
    "\n",
    "def plot_clusters(component1, component2, cluster, name, response,\n",
    "                   data_source, \n",
    "                   umap_params, clustering_params,\n",
    "                   plot_location):\n",
    "    \"\"\"Write an interactive scatter plot of the clustered responses to an HTML file.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    component1, component2 : pandas.Series\n",
    "        x and y coordinates of every response.\n",
    "    cluster : pandas.Series\n",
    "        Cluster label of every response; each label gets its own colour.\n",
    "    name : pandas.Series\n",
    "        Cluster name of every response, used for the legend entries.\n",
    "    response : pandas.Series\n",
    "        Response text shown on hover, prefixed with the cluster label.\n",
    "    data_source : pandas.Series\n",
    "        Source of every response; each source gets its own marker symbol.\n",
    "    umap_params, clustering_params : dict\n",
    "        Settings reported in the plot title. `clustering_params` must contain\n",
    "        'algorithm_name' and 'num_clusters'; 'eps' and 'min_samples' are read\n",
    "        for DBSCAN.\n",
    "    plot_location : str\n",
    "        Path of the HTML file to write.\n",
    "\n",
    "    The Series are combined with boolean masks, so they need a common index.\n",
    "    As a side effect the default plotly renderer is set to \"browser\".\n",
    "    \"\"\"\n",
    "    pio.renderers.default = \"browser\"\n",
    "    color_palette = px.colors.qualitative.Light24\n",
    "\n",
    "    fig = go.Figure()\n",
    "\n",
    "    title_str = (\n",
    "        f\"UMAP Parameters: n_neighbors={umap_params['n_neighbors']}, n_components={umap_params['n_components']}, min_dist={umap_params['min_dist']}\"\n",
    "        f\" | Clustering: {clustering_params['algorithm_name']} with k={clustering_params['num_clusters']} clusters\"\n",
    "    )\n",
    "    if clustering_params['algorithm_name'] == 'DBSCAN':\n",
    "        title_str += f\" at EPS={clustering_params['eps']} and min_samples={clustering_params['min_samples']}\"\n",
    "\n",
    "    # Get unique clusters and data sources\n",
    "    unique_clusters = sorted(cluster.unique())\n",
    "    unique_data_sources = sorted(data_source.unique())\n",
    "\n",
    "    color_map = {uc: color_palette[i % len(color_palette)] for i, uc in enumerate(unique_clusters)}\n",
    "\n",
    "    # One marker symbol per data source. One or two sources keep the established\n",
    "    # square/diamond assignment; any other number of sources cycles through a list\n",
    "    # of symbols, so marker_symbols is always defined.\n",
    "    if len(unique_data_sources) == 1:\n",
    "        marker_symbols = {unique_data_sources[0]: 'square'}\n",
    "    elif len(unique_data_sources) == 2:\n",
    "        marker_symbols = {\n",
    "            unique_data_sources[1]: 'square',\n",
    "            unique_data_sources[0]: 'diamond'\n",
    "        }\n",
    "    else:\n",
    "        fallback_symbols = ['square', 'diamond', 'circle', 'cross', 'x', 'triangle-up', 'star']\n",
    "        marker_symbols = {\n",
    "            ds: fallback_symbols[i % len(fallback_symbols)]\n",
    "            for i, ds in enumerate(unique_data_sources)\n",
    "        }\n",
    "\n",
    "    # Add an empty trace for each source so that its marker shape appears in the legend\n",
    "    for ds, symbol in marker_symbols.items():\n",
    "        fig.add_trace(go.Scatter(\n",
    "            x=[None],\n",
    "            y=[None],\n",
    "            mode='markers',\n",
    "            marker=dict(\n",
    "                size=10,\n",
    "                symbol=symbol,\n",
    "            ),\n",
    "            name=f'{str(ds).capitalize()} (shape)'\n",
    "        ))\n",
    "\n",
    "    # One trace per (source, cluster) pair; each cluster name is listed in the legend only once\n",
    "    added_cluster_names = set()\n",
    "    for ds in unique_data_sources:\n",
    "        for uc in unique_clusters:\n",
    "            mask = (cluster == uc) & (data_source == ds)\n",
    "            if not mask.any():  # nothing to draw for this combination\n",
    "                continue\n",
    "\n",
    "            cluster_label = name[mask].iloc[0]\n",
    "            show_in_legend = cluster_label not in added_cluster_names\n",
    "            added_cluster_names.add(cluster_label)\n",
    "\n",
    "            fig.add_trace(go.Scatter(\n",
    "                x=component1[mask],\n",
    "                y=component2[mask],\n",
    "                mode='text+markers',\n",
    "                name=cluster_label if show_in_legend else None,\n",
    "                legendgroup=f'group{uc}',\n",
    "                showlegend=show_in_legend,\n",
    "                hovertext=str(uc) + \": \" + response[mask],\n",
    "                marker=dict(\n",
    "                    size=12,\n",
    "                    color=color_map[uc],\n",
    "                    symbol=marker_symbols[ds],  # Use the marker symbol based on the data source\n",
    "                    line_width=1,\n",
    "                    opacity=1\n",
    "                ),\n",
    "                textfont=dict(\n",
    "                    size=10,\n",
    "                    color='black'\n",
    "                )\n",
    "            ))\n",
    "\n",
    "    fig.update_layout(\n",
    "        margin=dict(l=100, r=100, b=100, t=100),\n",
    "        width=2000,\n",
    "        height=1200,\n",
    "        showlegend=True,\n",
    "        title=title_str,\n",
    "        paper_bgcolor='white',  # White background for the entire plot area\n",
    "        plot_bgcolor='white',\n",
    "        legend=dict(\n",
    "            yanchor=\"top\",\n",
    "            y=1,\n",
    "            xanchor=\"left\",\n",
    "            x=0.01,\n",
    "            bgcolor='rgba(255,255,255,0)'\n",
    "        )\n",
    "    )\n",
    "\n",
    "    fig.layout.template = 'ggplot2'\n",
    "    fig.write_html(plot_location)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load in data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Agent name for each scenario, in scenario order: scenario n uses agent_list[n-1].\n",
    "# The name is removed from the responses of its scenario before they are embedded.\n",
    "agent_list = ['Heinz', 'Josh', 'Brian', 'Liz', 'Mary', 'Brad', 'Darya', 'Eunice', 'Eamon', 'Cameron', 'Erica', 'Carl', 'Daniel', 'Andy', 'Ahmed', 'Eva', 'Jeff', 'Shania']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load first-person decision study data\n",
    "df_decision = pd.read_csv('../data/decision.csv')\n",
    "\n",
    "# add id column (1-based row number of each participant)\n",
    "df_decision = df_decision.reset_index().rename(columns={'index': 'id'})\n",
    "df_decision['id'] += 1\n",
    "\n",
    "# select only participants who finished\n",
    "df_decision = df_decision[df_decision['finished']]\n",
    "\n",
    "# exclude ids of participants who gave non-sensical responses\n",
    "# (the filtered frame has to be assigned back, otherwise the exclusion has no effect)\n",
    "exclude_ids = [21, 64, 72, 74, 84, 86, 89]\n",
    "df_decision = df_decision[~df_decision['id'].isin(exclude_ids)].reset_index(drop=True)\n",
    "\n",
    "# Select columns 'id', 'S1_1' to 'S18_1'\n",
    "response_columns = [f'S{scenario}_1' for scenario in range(1, 19)]\n",
    "df_decision = df_decision[['id'] + response_columns]\n",
    "\n",
    "# Melt the DataFrame to long format: one row per participant and scenario\n",
    "df_decision = pd.melt(df_decision, id_vars=['id'], var_name='context', value_name='decision')\n",
    "\n",
    "# Extract the scenario number from the column name, i.e. its first run of digits ('S12_1' -> 12)\n",
    "df_decision['context'] = df_decision['context'].str.extract(r'(\\d+)', expand=False).astype(int)\n",
    "\n",
    "# drop empty responses, rename the text column and\n",
    "# add a source column indicating that this is from the decision study\n",
    "df_decision = (\n",
    "    df_decision\n",
    "    .dropna(subset=['decision'])\n",
    "    .rename(columns={'decision': 'response'})\n",
    "    .assign(source='decision')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load third-person possibility generation study data:\n",
    "# keep the needed columns, call the text column 'response' and tag the source as 'pg'\n",
    "df_pg = (\n",
    "    pd.read_csv('../manualCoding/pg_coded_final.csv', index_col=0)\n",
    "    [['context', 'id', 'answer', 'text', 'value']]\n",
    "    .rename(columns={\"text\": \"response\"})\n",
    "    .assign(source='pg')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combine the decision study data and the possibility generation study data into one\n",
    "# data frame for clustering (outer merge on every column the two frames have in common)\n",
    "df = df_decision.merge(df_pg, how='outer')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prefix every response with its full scenario text (merged_text) to give the LLM context\n",
    "contexts_list = pd.read_csv('../materials/contextsTable.csv', index_col=0)['text']\n",
    "\n",
    "# Look up the scenario text by context number, build merged_text, then give the\n",
    "# response and scenario columns their final names\n",
    "df_merge = (\n",
    "    df.merge(contexts_list, left_on='context', right_index=True)\n",
    "      .assign(merged_text=lambda frame: frame['text'] + ' : ' + frame['response'])\n",
    "      .rename(columns={\"response\": \"response_original\", \"text\": \"scenario_text\"})\n",
    ")\n",
    "\n",
    "df = df_merge"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the merged data frame (responses from both studies plus scenario texts) for a visual check\n",
    "df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compute embeddings and Analyze clusters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### To run the analyses edit the values in this cell then run this cell and the one below\n",
    "\n",
    "\n",
    "# directory to send resulting plots and clusters\n",
    "dir = \"pg_decision_clusters/\"\n",
    "\n",
    "## data sources list: \"pg\" (just Study 1 possibility generation data), \"decision\" (just Study 2 decision data)\n",
    "## For convergence of participant responses, just \"pg\" should be in sources.\n",
    "## For clustering to model decision likelihood both \"pg\" and \"decision\" should be included in sources.\n",
    "sources = [\"pg\", \"decision\"]\n",
    "df_clustering = df[df['source'].isin(sources)]\n",
    "\n",
    "## text_column_name is the column the embeddings will be computed on\n",
    "## text_column_name options: 'response_original' or 'merged_text'\n",
    "# merged_text gives context by merging response with the scenario text and generally gives better results\n",
    "text_column_name = 'merged_text'\n",
    "original_text_column_name = 'response_original' # this is used so you can keep track of original texts for labelling\n",
    "\n",
    "\n",
    "# Define UMAP and clustering parameters\n",
    "# NOTE: with 'random_state': None the UMAP projection differs between runs;\n",
    "# set an integer seed here for reproducible results.\n",
    "umap_params = {'n_neighbors': 100, 'n_components': 10, 'metric': 'cosine', 'min_dist': 0.05, 'random_state': None}\n",
    "\n",
    "\n",
    "# Define range of k values to test kMeans clustering at\n",
    "k_range = range(1, 19)\n",
    "\n",
    "# Define range of epsilon and min_samples values to test DBSCAN clustering at\n",
    "# NOTE(review): this list previously contained 0.5 three times, which only repeated identical runs.\n",
    "# The duplicates are removed; if 0.6 was intended as one of them, add it here.\n",
    "eps_list = [0.2, 0.3, 0.4, 0.5, 0.7, 0.8, 0.9]\n",
    "samples_list = [3,4,5,6,7,8,9,10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create directory to store results in\n",
    "os.makedirs(dir, exist_ok=True)\n",
    "\n",
    "\n",
    "# loop through each context\n",
    "for scenario_number in range(1,19):\n",
    "    print(f\"Computing embeddings for scenario {scenario_number}\")\n",
    "    # select appropriate agent name\n",
    "    agent_name = agent_list[scenario_number-1]\n",
    "    # list containing words that will be removed from each response before embedding\n",
    "    words_to_remove = [agent_name, 'should', 'would', 'could']\n",
    "\n",
    "    # Strip each response for context of words_to_remove and compute SBERT embeddings\n",
    "    text_embeddings, texts_df = compute_embeddings(df_clustering, text_column_name, scenario_number, words_to_remove)\n",
    "    # reduce dimensionality of embeddings according to predetermined parameters\n",
    "    umap_embeddings = reduce_dimensions(texts_df, text_embeddings, umap_params)\n",
    "\n",
    "    ### compute DBSCAN clustering on embeddings ###\n",
    "    print(\"Performing DBSCAN clustering on embeddings\")\n",
    "\n",
    "    for eps in eps_list:\n",
    "        for min_samples in samples_list:\n",
    "            # create directories to store resulting plots and tables\n",
    "            tempdir = os.path.join(dir, f\"eps{eps}_samp{min_samples}\")\n",
    "            plots_dir = os.path.join(tempdir, \"plots\")\n",
    "            tables_dir = os.path.join(tempdir, \"tables\")\n",
    "            os.makedirs(plots_dir, exist_ok=True)\n",
    "            os.makedirs(tables_dir, exist_ok=True)\n",
    "\n",
    "            clustering_params = {'algorithm_name': 'DBSCAN', 'eps': eps, 'min_samples': min_samples}\n",
    "\n",
    "            # compute clusters on text embeddings\n",
    "            cluster_assignment, clustering_model = compute_clusters(umap_embeddings, clustering_params)\n",
    "\n",
    "            # store number of clusters generated by DBSCAN, not including those marked as outliers\n",
    "            num_clusters = len(set(cluster_assignment)) - (1 if -1 in cluster_assignment else 0)\n",
    "            clustering_params['num_clusters'] = num_clusters\n",
    "\n",
    "            # Work on a copy per parameter combination, so that cluster columns written for one\n",
    "            # combination never leak into the table saved for another one\n",
    "            run_df = texts_df.copy()\n",
    "            run_df['cluster'] = cluster_assignment\n",
    "\n",
    "            if num_clusters > 1:\n",
    "                # name each cluster\n",
    "                clusters = compute_centroid(run_df, text_column_name, original_text_column_name, clustering_params, cluster_assignment, clustering_model)\n",
    "\n",
    "                # Plot clusters with the appropriate DataFrame columns\n",
    "                plot_clusters(\n",
    "                    run_df['umap_dim2_0'],\n",
    "                    run_df['umap_dim2_1'],\n",
    "                    run_df['cluster'],\n",
    "                    run_df['cluster_name'],\n",
    "                    run_df['response_original'],\n",
    "                    run_df['source'],\n",
    "                    umap_params,\n",
    "                    clustering_params,\n",
    "                    os.path.join(plots_dir, f\"S{scenario_number}.html\")\n",
    "                )\n",
    "\n",
    "            # save the table for this scenario and parameter combination\n",
    "            run_df.to_csv(os.path.join(tables_dir, f\"S{scenario_number}.csv\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
